clear all;close all;clc;
addpath baseline;
addpath tools;
% matlabpool('open',12);
%% Rayleigh channel, narrow-band

%% basic simulation settings
M = 32;     % BS antenna count
N = 256;    % RIS element count
K = 8;      % user count
L = 2;      % transmitter antennas used in the large-timescale channel estimation

%% system power settings
Tx_power_UE = 1;
Tx_power_BS = 10;

%% large-scale fading settings
% Every link variance is the reference variance scaled by
% (reference distance / link distance)^exponent.
sigma2_ref = 0.01;
D_ref = 1;

% BS-RIS link (G)
D_G = 20;
alpha_G = 2.1;

% direct BS-UE link (d)
D_d = 30;
alpha_d = 4.2;

% RIS-UE link (r)
D_r = 20;
alpha_r = 2.2;

sigma2_G = sigma2_ref*(D_ref/D_G)^alpha_G;
sigma2_r = sigma2_ref*(D_ref/D_r)^alpha_r;
sigma2_d = sigma2_ref*(D_ref/D_d)^alpha_d;

%% Monte-Carlo simulation settings
N_trial = 1000;
SNR_S_list = -30:5:30;
% one noise variance per SNR point, referenced to the UE power times the
% product of the RIS-UE and BS-RIS variances
Noise_var_list = Tx_power_UE*(sigma2_r*sigma2_G)*10.^(-SNR_S_list/10);
N_noise_list = numel(SNR_S_list);

%% allocate some memory to store the results
% per-trial channel energies
[Norm_G, Norm_H_d, Norm_H_r, Norm_H_c] = deal(zeros(N_trial,1));

% every remaining buffer is (noise level) x (trial)
SR_perfectCSI = zeros(N_noise_list,N_trial);

% "mine" results
[SE_H_d_mine, SE_H_r_mine, SE_H_c_mine, SR_mine] = deal(zeros(N_noise_list,N_trial));

% "mine_mp" results
[SE_H_d_mine_mp, SE_H_r_mine_mp, SE_H_c_mine_mp, SR_mine_mp] = deal(zeros(N_noise_list,N_trial));

% MU results
[SE_H_d_MU, SE_H_c_MU, SR_MU] = deal(zeros(N_noise_list,N_trial));

% CS results
[SE_H_d_CS, SE_H_c_CS, SR_CS] = deal(zeros(N_noise_list,N_trial));

% MVU results
[SE_H_d_MVU, SE_H_c_MVU, SR_MVU] = deal(zeros(N_noise_list,N_trial));

%% begin the Monte-Carlo simulations
% For every trial a fresh set of Rayleigh channels is drawn.  For every noise
% level the script then evaluates the downlink sum rate obtained with
%   (a) perfect CSI,
%   (b) the proposed two-phase estimation framework,
%   (c) the WangZ multi-user scheme,
%   (d) a compressive-sensing (matrix OMP) estimator, and
%   (e) an MVU estimator with DFT reflection patterns,
% together with the squared channel-estimation errors of (b)-(e).
% The whole workspace is checkpointed to result.mat after every trial.

% ---- helpers shared by all methods ---------------------------------------
% Sum rate in bit/s/Hz for an effective (user x stream) channel H_eff with
% equal power Tx_power_BS/K per stream.  log2 is used so that the value
% matches the "bit/s/Hz" unit printed below and used on the plot axis.
sum_rate_bits = @(H_eff, noise_var) sum(log2(1 + ...
    (Tx_power_BS/K)*abs(diag(H_eff)).^2 ./ ...
    ((Tx_power_BS/K)*(sum(abs(H_eff).^2,2) - abs(diag(H_eff)).^2) + noise_var)));
% total squared error between two arrays of equal size (up to 3-D)
sq_err = @(A_true, A_est) sum(sum(sum(abs(A_true - A_est).^2)));

% ---- loop-invariant quantities -------------------------------------------
% Hadamard matrix from which the phase-1 reflection patterns are drawn
N_plus = 2^ceil(log2(N+1));
Ha = hadamard(N_plus);

% compressive-sensing baseline settings
N_P = 32;% number of pilots
sparsity = 100;
M1 = 4*M;% dense grid
N1 = 4*N;% dense grid
D_BS = exp(-1i*2*pi*(0:1:(M-1))'*(0:1:(M1-1))/M1);% angular dictionary
D_RIS = exp(-1i*2*pi*(0:1:(N1-1))'*(0:1:(N-1))/N1);% angular dictionary
X = dftmtx(K)*sqrt(Tx_power_UE);% orthogonal UE pilots, X*X' = K*Tx_power_UE*I

for trial_idx = 1:N_trial
    %% generate the channels
    fprintf('==========================================================================================================================================\n');
    G = (normrnd(0,1,M,N) + 1i*normrnd(0,1,M,N))*sqrt(sigma2_G/2);     % the channel matrix between the BS and the RIS
    H_d = (normrnd(0,1,M,K) + 1i*normrnd(0,1,M,K))*sqrt(sigma2_d/2);   % the channel matrix between the BS and the UE (direct link)
    H_r = (normrnd(0,1,N,K) + 1i*normrnd(0,1,N,K))*sqrt(sigma2_r/2);   % the channel matrix between the RIS and the UE (reflected link)
    % cascaded channel: H_c(i,j,k) = G(i,j)*H_r(j,k)
    H_c = bsxfun(@times, G, reshape(H_r, [1 N K]));

    Norm_G(trial_idx) = norm(G,'fro')^2;
    Norm_H_d(trial_idx) = norm(H_d,'fro')^2;
    Norm_H_r(trial_idx) = norm(H_r,'fro')^2;
    Norm_H_c(trial_idx) = sum(sum(sum(abs(H_c).^2)));

    % channel energies used to normalise the squared errors of this trial
    energy_d = sum(sum(abs(H_d).^2));
    energy_c = Norm_H_c(trial_idx);

    for noise_idx = 1:N_noise_list
        if (noise_idx>1)
            fprintf('++--------------+-------------+-----------------+-------------------+-------------------+-------------------+---------------------------++\n');
        end
        noise_power = Noise_var_list(noise_idx);

        %% calculate the sum rate with perfect CSI
        % The effective channel always uses the TRUE G, H_r and H_d; only the
        % inputs handed to ACE_RIS_1bit differ between the methods.
        [~,D,THETA] = ACE_RIS_1bit(1000,H_c,H_d,Tx_power_BS/noise_power);
        H_precoding = ((H_r.')*diag(THETA)*(G.')+(H_d.'))*(D);
        SR_perfectCSI(noise_idx,trial_idx) = sum_rate_bits(H_precoding, noise_power);

    %% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%  
    %% %%%%%%%% THE PROPOSED CHANNEL ESTIMATION FRAMEWORK %%%%%%%% %%
    %% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%

        %% Phase 1: estimating the channel between the BS and the RIS
        options_Phase1 = struct;
        options_Phase1.M = M;
        options_Phase1.N = N;
        options_Phase1.SI_plus_noise_var = noise_power*3;
        options_Phase1.Tx_power = Tx_power_BS;
        options_Phase1.L = L;
        options_Phase1.Tx_id = 1:L;
        options_Phase1.sigma2_G = sigma2_G;
        options_Phase1.I_max = 10; % max number of outer iterations in the coordinate descent in phase-1

        % keep row/column 1 of the Hadamard matrix and a random selection of N others
        aa = [1 randperm(N)+1];
        Phi = Ha(aa,aa);

        S = (normrnd(0,1,M,M) + 1i*normrnd(0,1,M,M))*sqrt(1/2)*10;         % environmental reflection channel
        Y = ReflectionPilotTransmission( G, S*sqrt(options_Phase1.SI_plus_noise_var), Phi, options_Phase1 ); % pilot transmission, reflection, and reception
        G_est = ReflectionChannelEstimation( Y, Phi, options_Phase1, G );     % run Algorithm 1 to estimate the channel

        %% Phase 2: estimating the channels between the BS and the UEs, and the channels between the RIS and the UEs
        options_Phase2 = struct;
        options_Phase2.M = M;
        options_Phase2.N = N;
        options_Phase2.K = K;
        options_Phase2.noise_var = noise_power;
        options_Phase2.Tx_power = Tx_power_UE;
        options_Phase2.N_slot = ceil(N/M)+1;
        [Y, Phi, P, options_Phase2] = UEUplinkPilotTransmission( G, H_d, H_r, options_Phase2 );
        [H_est, ~, ~, se_r, ~ ] = UEUplinkChannelEstimation( Y, G_est, Phi, P, options_Phase2, H_d, H_r );

        %% evaluate the estimation error
        % rows 1..M of H_est hold the direct-link estimate, rows M+1..M+N the
        % RIS-UE estimate that is combined with G_est into the cascaded channel
        H_d_est = H_est(1:M,:);
        se_d = sq_err(H_d, H_d_est);
        nse_d = se_d/energy_d;
        H_c_est = bsxfun(@times, G_est, reshape(H_est(M+1:M+N,1:K), [1 N K]));
        se_c = sq_err(H_c, H_c_est);
        nse_c = se_c/energy_c;
        se_s = se_c + se_d;
        nse_s = se_s/(energy_c+energy_d);

        [~,D,THETA] = ACE_RIS_1bit(1000,H_c_est,H_d_est,Tx_power_BS/noise_power);
        H_precoding = ((H_r.')*diag(THETA)*(G.')+(H_d.'))*(D);
        SR = sum_rate_bits(H_precoding, noise_power);

        %% store the results and display some information
        SE_H_d_mine(noise_idx,trial_idx) = se_d;
        SE_H_r_mine(noise_idx,trial_idx) = se_r;
        SE_H_c_mine(noise_idx,trial_idx) = se_c;
        SR_mine(noise_idx,trial_idx) = SR;

        fprintf('|| trial = %-4d | SNR = %-2d dB |  method = MINE  | NSE_d = %-6.2f dB | NSE_c = %-6.2f dB | NSE_s = %-6.2f dB | Sum Rate = %-5.2f bit/s/Hz ||\n',trial_idx,SNR_S_list(noise_idx),10*log10(nse_d),10*log10(nse_c),10*log10(nse_s),SR);

    %% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%  
    %% %%%%%%% Wang Zhaorui's MU CHANNEL ESTIMATION SCHEME %%%%%%% %%
    %% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%
        options_MU = struct;
        options_MU.M = M;
        options_MU.N = N;
        options_MU.K = K;
        options_MU.noise_var = noise_power;
        options_MU.Tx_power = Tx_power_UE;

        [akt,phint] = WangZ_PilotAllocation(options_MU);
        [Y_1,Y_2,Y_3] = WangZ_PilotTransmission(G,H_d,H_r,options_MU,akt,phint);
        [H_d_est,H_ref,H_r_est] = WangZ_ChannelEstimation(Y_1,Y_2,Y_3,options_MU,akt,phint);
        %% evaluate the estimation error
        se_d = sq_err(H_d, H_d_est);
        nse_d = se_d/energy_d;

        % user 1 provides the reference cascaded channel; the cascaded channel
        % of user k>1 is that reference with column j scaled by H_r_est(j,k-1)
        H_c_est = zeros(M, N, K);
        H_c_est(:, :, 1) = H_ref;
        for k = 2:K
            H_c_est(:, :, k) = bsxfun(@times, H_ref, H_r_est(:,k-1).');
        end
        se_c = sq_err(H_c, H_c_est);
        nse_c = se_c/energy_c;
        se_s = se_c + se_d;
        nse_s = se_s/(energy_c+energy_d);

        [~,D,THETA] = ACE_RIS_1bit(1000,H_c_est,H_d_est,Tx_power_BS/noise_power);
        H_precoding = ((H_r.')*diag(THETA)*(G.')+(H_d.'))*(D);
        SR = sum_rate_bits(H_precoding, noise_power);

        %% store the results and display some information
        SE_H_d_MU(noise_idx,trial_idx) = se_d;
        SE_H_c_MU(noise_idx,trial_idx) = se_c;
        SR_MU(noise_idx,trial_idx) = SR;

        fprintf('|| trial = %-4d | SNR = %-2d dB |  method = WangZ | NSE_d = %-6.2f dB | NSE_c = %-6.2f dB | NSE_s = %-6.2f dB | Sum Rate = %-5.2f bit/s/Hz ||\n',trial_idx,SNR_S_list(noise_idx),10*log10(nse_d),10*log10(nse_c),10*log10(nse_s),SR);

    %% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%  
    %% %%%%%%%%%%%%% COMPRESSIVE CHANNEL ESTIMATION %%%%%%%%%%%%%% %%
    %% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%
        % random unit-modulus reflection patterns, one column per pilot block
        Phi = exp(1i*2*pi*rand(N,N_P));
        H_c_est = zeros(M, N, K);

        %% estimate direct channel
        Y = H_d*X + (normrnd(0,1,M,K) + 1i*normrnd(0,1,M,K))*sqrt(noise_power/2);
        H_d_est = Y*X'/K/Tx_power_UE;

        %% estimate cascaded channel
        Y = zeros(M,K,N_P);
        Y1 = zeros(M,K,N_P);
        for p = 1:N_P
            % G*diag(phi)*H_r written as a row scaling of H_r.  The noise uses
            % sqrt(noise_power/2) per real dimension, i.e. the same variance
            % noise_power as the direct-link pilots above and the MVU pilots.
            Y(:,:,p) = (H_d + G*bsxfun(@times, Phi(:,p), H_r))*X + (normrnd(0,1,M,K) + 1i*normrnd(0,1,M,K))*sqrt(noise_power/2);
            Y1(:,:,p) = Y(:,:,p)*X'/K/Tx_power_UE;
        end
        for k = 1:K
            % observations of user k over all patterns, direct-link estimate removed
            Yk = bsxfun(@minus, reshape(Y1(:,k,:), M, N_P), H_d_est(:,k));
            S_sparse = matrixOMP(D_BS, D_RIS*Phi, Yk, sparsity, 1e-30);
            H_c_est(:,:,k) = D_BS*S_sparse*D_RIS;
        end

        %% evaluate the estimation error
        se_d = sq_err(H_d, H_d_est);
        nse_d = se_d/energy_d;
        se_c = sq_err(H_c, H_c_est);
        nse_c = se_c/energy_c;
        se_s = se_c + se_d;
        nse_s = se_s/(energy_c+energy_d);

        [~,D,THETA] = ACE_RIS_1bit(1000,H_c_est,H_d_est,Tx_power_BS/noise_power);
        H_precoding = ((H_r.')*diag(THETA)*(G.')+(H_d.'))*(D);
        SR = sum_rate_bits(H_precoding, noise_power);

        %% store the results and display some information
        SE_H_d_CS(noise_idx,trial_idx) = se_d;
        SE_H_c_CS(noise_idx,trial_idx) = se_c;
        SR_CS(noise_idx,trial_idx) = SR;

        fprintf('|| trial = %-4d | SNR = %-2d dB |  method = CS    | NSE_d = %-6.2f dB | NSE_c = %-6.2f dB | NSE_s = %-6.2f dB | Sum Rate = %-5.2f bit/s/Hz ||\n',trial_idx,SNR_S_list(noise_idx),10*log10(nse_d),10*log10(nse_c),10*log10(nse_s),SR);

    %% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%  
    %% %%%%%%%%%%%%%%%%% MVU CHANNEL ESTIMATION %%%%%%%%%%%%%%%%%% %%
    %% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%
        % N+1 pilot slots; slot n applies the reflection pattern Phi(2:end,n)
        Phi = dftmtx(N+1);
        % observation matrix: first column all ones, remaining columns from Phi
        A = [ones(N+1,1), Phi(:,2:end)];
        H_d_est = zeros(M,K);
        H_c_est = zeros(M, N, K);
        for k = 1:K
            % column n of Y is H_d(:,k) + H_c(:,:,k)*Phi(2:end,n), scaled by the
            % pilot amplitude and corrupted by noise of variance noise_power
            Y = (H_d(:,k)*ones(1,N+1) + H_c(:,:,k)*Phi(2:end,:))*sqrt(Tx_power_UE) + (normrnd(0,1,M,N+1) + 1i*normrnd(0,1,M,N+1))*sqrt(noise_power/2);
            % solve H_est*A = Y/sqrt(Tx_power_UE) without forming inv(A)
            H_est = (Y/A)/sqrt(Tx_power_UE);
            H_d_est(:,k) = H_est(:,1);
            H_c_est(:,:,k) = H_est(:,2:N+1);
        end

        %% evaluate the estimation error
        se_d = sq_err(H_d, H_d_est);
        nse_d = se_d/energy_d;
        se_c = sq_err(H_c, H_c_est);
        nse_c = se_c/energy_c;
        se_s = se_c + se_d;
        nse_s = se_s/(energy_c+energy_d);

        [~,D,THETA] = ACE_RIS_1bit(1000,H_c_est,H_d_est,Tx_power_BS/noise_power);
        H_precoding = ((H_r.')*diag(THETA)*(G.')+(H_d.'))*(D);
        SR = sum_rate_bits(H_precoding, noise_power);

        %% store the results and display some information
        SE_H_d_MVU(noise_idx,trial_idx) = se_d;
        SE_H_c_MVU(noise_idx,trial_idx) = se_c;
        SR_MVU(noise_idx,trial_idx) = SR;

        fprintf('|| trial = %-4d | SNR = %-2d dB |  method = MVU   | NSE_d = %-6.2f dB | NSE_c = %-6.2f dB | NSE_s = %-6.2f dB | Sum Rate = %-5.2f bit/s/Hz ||\n',trial_idx,SNR_S_list(noise_idx),10*log10(nse_d),10*log10(nse_c),10*log10(nse_s),SR);
    end
    % checkpoint the whole workspace after every trial
    save result.mat
end
fprintf('======================================================================================================================\n');
save result.mat



%% Figure 1: average sum rate versus SNR (mean over the trials)
% Six curves are drawn, so the legend needs six labels in plotting order.
% NOTE(review): SR_mine_mp and SE_H_c_mine_mp are preallocated but never
% written anywhere in the visible script, so the "more pilots" curves stay at
% zero unless they are filled elsewhere - confirm before publishing.
figure;
plot(SNR_S_list,mean(SR_mine,2),'r-v','LineWidth',1.5);
hold on;   % the first plot creates the axes; all later curves are added to it
p=plot(SNR_S_list,mean(SR_mine_mp,2),'r-^','LineWidth',1.5);
set(p,'Color',[0.8 0 0.8]);
plot(SNR_S_list,mean(SR_MU,2),'b-d','LineWidth',1.5);
p=plot(SNR_S_list,mean(SR_MVU,2),'g-o','LineWidth',1.5);
set(p,'Color',[0 0.7 0]);
p=plot(SNR_S_list,mean(SR_CS,2),'g-*','LineWidth',1.5);
set(p,'Color',[0 0.7 0]);
plot(SNR_S_list,mean(SR_perfectCSI,2),'k--','LineWidth',1.5,'Markersize',10);
grid on;
xlabel('SNR');
ylabel('sum rate (bit/s/Hz)');
legend('proposed method (minimum pilots)','proposed method (more pilots)','MU','MVU','CS','perfect CSI');

%% Figure 2: NMSE of the cascaded channel versus SNR
% Each curve is the squared error summed over the trials divided by the
% cascaded-channel energy summed over the trials.  Five curves, five labels.
figure;
semilogy(SNR_S_list,sum(SE_H_c_mine,2)/sum(Norm_H_c),'r-v','LineWidth',1.5);
hold on;   % issued after the first semilogy so the logarithmic y-axis is kept
p=semilogy(SNR_S_list,sum(SE_H_c_mine_mp,2)/sum(Norm_H_c),'r-^','LineWidth',1.5);
set(p,'Color',[0.8 0 0.8]);
semilogy(SNR_S_list,sum(SE_H_c_MU,2)/sum(Norm_H_c),'b-d','LineWidth',1.5);
p=semilogy(SNR_S_list,sum(SE_H_c_MVU,2)/sum(Norm_H_c),'g-o','LineWidth',1.5);
set(p,'Color',[0 0.7 0]);
p=semilogy(SNR_S_list,sum(SE_H_c_CS,2)/sum(Norm_H_c),'g-*','LineWidth',1.5);
set(p,'Color',[0 0.7 0]);
grid on;
xlabel('SNR');
ylabel('NMSE');
legend('proposed method (minimum pilots)','proposed method (more pilots)','MU','MVU','CS');
% matlabpool close;